
from datasets import load_dataset,load_from_disk
from torch.utils.data import Dataset
# Data loading (runs at import time).
data_cache = "./data_cache"  # cache dir; in this file only referenced by the commented-out Hub download below
# Alternative path: download from the Hugging Face Hub, then export it to CSV.
#dataset = load_dataset("lansinuote/ChnSentiCorp",cache_dir=data_cache,split="train") # download and load online
#dataset.to_csv(path_or_buf=r"E:\nlp\data_cache\ChnSentiCorp.csv") # convert to CSV format
# Load the previously exported CSV from local disk. The absolute path below is
# machine-specific and must exist on the host running this module.
# NOTE(review): with a single `data_files` path, `datasets` typically exposes only a
# "train" split — confirm; MyData("test") / MyData("val") would then raise KeyError.
dataset2 = load_dataset(path="csv",data_files="/home/ubuntu/jack/python/nlp/data_cache/ChnSentiCorp.csv")# load from local CSV


# Dataset class

class MyData(Dataset):
    """Map-style torch ``Dataset`` over one split of a text/label dataset.

    Each item is a ``(text, label)`` tuple taken from the ``"text"`` and
    ``"label"`` columns of the selected split.

    Args:
        slipt: Name of the split to use (e.g. ``"train"``, ``"test"``,
            ``"val"``). The misspelled parameter name is kept for backward
            compatibility with existing callers. ``"val"`` falls back to a
            ``"validation"`` split when no ``"val"`` split exists.
        dataset: Optional mapping of split name -> rows. Defaults to the
            module-level ``dataset2`` loaded at import time.

    Raises:
        KeyError: If the requested split is not present in the dataset.
    """

    def __init__(self, slipt, dataset=None):
        super().__init__()
        # Resolve the module-level dataset lazily so callers can inject their own.
        source = dataset2 if dataset is None else dataset
        if slipt == "val" and slipt not in source and "validation" in source:
            # Many Hugging Face datasets name this split "validation" rather than "val".
            slipt = "validation"
        if slipt not in source:
            # Fail early: previously an unknown split silently kept the whole
            # split mapping, which only broke later (and confusingly) on access.
            raise KeyError(
                f"split {slipt!r} not found; available splits: {list(source)}"
            )
        self.data = source[slipt]

    def __getitem__(self, index):
        """Return the ``(text, label)`` pair stored at ``index``."""
        # Look the row up once instead of once per column.
        row = self.data[index]
        return row["text"], row["label"]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def __iter__(self):
        """Yield every ``(text, label)`` pair in index order."""
        # Explicit iterator so ``for x in MyData(...)`` does not rely on the
        # legacy __getitem__/IndexError sequence-iteration fallback.
        for i in range(len(self)):
            yield self[i]
    
if __name__ == "__main__":
    # Manual smoke check: build the training split and print every
    # (text, label) sample it produces.
    train_set = MyData("train")
    for sample in train_set:
        print(sample)
        
